{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## 🌟 Hands-on Practice with Function Call\n",
    "\n",
    "Next, we'll implement a standalone function call example using the Yi-1.5-9B-Chat model. This example doesn't rely on any specific framework but directly uses the Hugging Face transformers library to load and use the model.\n",
    "\n",
    "First, let's install the necessary dependencies:\n",
    "⚠️ Please be mindful of your computer's GPU memory here"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "!pip install transformers torch"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Now let's start building our function call implementation step by step.\n",
    "Just like before, we'll use addition, subtraction, and multiplication functions as examples.\n",
    "#### Step 1: Import necessary libraries and define functions"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import json\n",
    "import torch\n",
    "from transformers import AutoTokenizer, AutoModelForCausalLM\n",
    "\n",
    "# Define available functions\n",
    "def multiply(a: int, b: int) -> int:\n",
    "    return a * b\n",
    "\n",
    "def plus(a: int, b: int) -> int:\n",
    "    return a + b\n",
    "\n",
    "def minus(a: int, b: int) -> int:\n",
    "    return a - b\n",
    "# You can add your own functions here if needed\n",
    "# Function mapping\n",
    "available_functions = {\n",
    "    \"multiply\": multiply,\n",
    "    \"plus\": plus,\n",
    "    \"minus\": minus\n",
    "}"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "#### Step 2: Load the Yi model and tokenizer (we'll use transformers for this step)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Load Yi model and tokenizer\n",
    "model_path = \"01-ai/Yi-1.5-9B-Chat\"\n",
    "tokenizer = AutoTokenizer.from_pretrained(model_path, use_fast=False)\n",
    "model = AutoModelForCausalLM.from_pretrained(model_path, torch_dtype=torch.float16, device_map=\"auto\")"
   ]
  },
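  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "To judge whether the model will fit on your GPU, a rough back-of-the-envelope estimate helps. The sketch below is not part of the original example: `estimate_weight_memory_gib` is a hypothetical helper, and it takes the 9B in the model name at face value as the parameter count. It covers the float16 weights only; activations and the KV cache need additional memory on top."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Hypothetical helper: rough size of the model weights in GiB\n",
    "def estimate_weight_memory_gib(num_params: float, bytes_per_param: int = 2) -> float:\n",
    "    # float16 stores each parameter in 2 bytes; 1 GiB = 1024**3 bytes\n",
    "    return num_params * bytes_per_param / 1024**3\n",
    "\n",
    "# Assumption: about 9e9 parameters, read off the model name\n",
    "print(f\"Approximate float16 weights: {estimate_weight_memory_gib(9e9):.1f} GiB\")"
   ]
  },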
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "#### Step 3: Implement functionality to generate responses and parse function calls"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Generate Yi's response\n",
    "def generate_response(prompt):\n",
    "    inputs = tokenizer(prompt, return_tensors=\"pt\").to(model.device)\n",
    "    outputs = model.generate(**inputs, temperature=0.7, top_p=0.95)\n",
    "    response = tokenizer.decode(outputs[0], skip_special_tokens=True)\n",
    "    return response.split(\"Human:\")[0].strip()\n",
    "\n",
    "# Parse output, extract function information\n",
    "def parse_function_call(response):\n",
    "    try:\n",
    "        # Try to extract JSON-formatted function call from the response\n",
    "        start = response.index(\"{\")\n",
    "        end = response.rindex(\"}\") + 1\n",
    "        function_call_json = response[start:end]\n",
    "        function_call = json.loads(function_call_json)\n",
    "        return function_call\n",
    "    except (ValueError, json.JSONDecodeError):\n",
    "        return None\n",
    "\n",
    "# Execute function call (bridge for function information transfer)\n",
    "def execute_function(function_name: str, arguments: dict) -> Any:\n",
    "    if function_name in available_functions:\n",
    "        return available_functions[function_name](**arguments)\n",
    "    else:\n",
    "        raise ValueError(f\"Function {function_name} not found\")"
   ]
  },
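  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "The parser above accepts any JSON object, so a reply with a missing key or a wrong argument name only fails once the function is called. One possible safeguard is to validate the parsed call first. This is a minimal sketch, not part of the original example: `validate_function_call` and `demo_registry` are hypothetical names, and the cell defines its own stand-in registry so it runs without the model."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import inspect\n",
    "import json\n",
    "from typing import Any, Callable, Dict, Tuple\n",
    "\n",
    "# Hypothetical helper: check a parsed call before executing it\n",
    "def validate_function_call(call: Any, registry: Dict[str, Callable[..., Any]]) -> Tuple[str, Dict[str, Any]]:\n",
    "    if not isinstance(call, dict):\n",
    "        raise ValueError(\"Function call must be a JSON object\")\n",
    "    name = call.get(\"function\")\n",
    "    arguments = call.get(\"arguments\")\n",
    "    if not isinstance(name, str) or name not in registry:\n",
    "        raise ValueError(f\"Function {name!r} not found\")\n",
    "    if not isinstance(arguments, dict):\n",
    "        raise ValueError(\"Arguments must be a JSON object\")\n",
    "    try:\n",
    "        # bind() raises TypeError on missing or unexpected argument names\n",
    "        inspect.signature(registry[name]).bind(**arguments)\n",
    "    except TypeError as exc:\n",
    "        raise ValueError(f\"Bad arguments for {name}: {exc}\") from exc\n",
    "    return name, arguments\n",
    "\n",
    "# Stand-in registry so this cell runs on its own\n",
    "demo_registry = {\"plus\": lambda a, b: a + b}\n",
    "\n",
    "good = json.loads('{\"function\": \"plus\", \"arguments\": {\"a\": 5, \"b\": 3}}')\n",
    "name, arguments = validate_function_call(good, demo_registry)\n",
    "print(demo_registry[name](**arguments))\n",
    "\n",
    "bad = json.loads('{\"function\": \"plus\", \"arguments\": {\"a\": 5}}')\n",
    "try:\n",
    "    validate_function_call(bad, demo_registry)\n",
    "except ValueError as exc:\n",
    "    print(f\"Rejected: {exc}\")"
   ]
  },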
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "#### Step 4: Implement the main loop"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Main loop\n",
    "def main():\n",
    "    # This prompt template contains function information. If you need to add other functions, you should synchronize them with the large model here as well\n",
    "    system_prompt = \"\"\"You are an AI assistant capable of calling functions to perform tasks. When a user asks a question that requires calling a function, respond with a JSON object containing the function name and arguments. Available functions are:\n",
    "    - multiply(a: int, b: int) -> int: Multiplies two integers\n",
    "    - plus(a: int, b: int) -> int: Adds two integers\n",
    "    - minus(a: int, b: int) -> int: Subtracts two integers\n",
    "    For example, if the user asks \"What is 5 plus 3?\", respond with:\n",
    "    {\"function\": \"plus\", \"arguments\": {\"a\": 5, \"b\": 3}}\n",
    "    If no function call is needed, respond normally.\"\"\"\n",
    "\n",
    "    conversation_history = [f\"System: {system_prompt}\"]\n",
    "\n",
    "    while True:\n",
    "        user_input = input(\"Human: \")\n",
    "        if user_input.lower() == 'exit':\n",
    "            break\n",
    "\n",
    "        # Add user input to conversation history\n",
    "        conversation_history.append(f\"Human: {user_input}\")\n",
    "        \n",
    "        # Build complete prompt\n",
    "        full_prompt = \"\\n\".join(conversation_history) + \"\\nAssistant:\"\n",
    "\n",
    "        # Get model response\n",
    "        response = generate_response(full_prompt)\n",
    "        print(f\"Model response: {response}\")\n",
    "\n",
    "        # Parse possible function call\n",
    "        function_call = parse_function_call(response)\n",
    "\n",
    "        if function_call:\n",
    "            function_name = function_call[\"function\"]\n",
    "            arguments = function_call[\"arguments\"]\n",
    "\n",
    "            try:\n",
    "                # Execute function\n",
    "                result = execute_function(function_name, arguments)\n",
    "\n",
    "                # Add result to conversation history\n",
    "                conversation_history.append(f\"Assistant: The result of {function_name}({arguments}) is {result}\")\n",
    "                print(f\"Assistant: The result of {function_name}({arguments}) is {result}\")\n",
    "            except Exception as e:\n",
    "                error_message = f\"An error occurred: {str(e)}\"\n",
    "                conversation_history.append(f\"Assistant: {error_message}\")\n",
    "                print(f\"Assistant: {error_message}\")\n",
    "        else:\n",
    "            # If no function call, directly add model's response to conversation history\n",
    "            conversation_history.append(f\"Assistant: {response}\")\n",
    "            print(f\"Assistant: {response}\")\n",
    "\n",
    "if __name__ == \"__main__\":\n",
    "    main()"
   ]
  },
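  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "The function list in `system_prompt` is written by hand, so it can drift from `available_functions`. One way to keep the two in sync is to generate the list from the registry. This is a sketch under assumptions, not part of the original example: `describe_functions`, `divide`, and `prompt_registry` are hypothetical names, and each function is assumed to carry type hints and a one-line docstring. The returned text could stand in for the hand-written list in the prompt."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import inspect\n",
    "from typing import Any, Callable, Dict\n",
    "\n",
    "# Hypothetical helper: render one prompt line per registered function\n",
    "def describe_functions(registry: Dict[str, Callable[..., Any]]) -> str:\n",
    "    lines = []\n",
    "    for name, func in registry.items():\n",
    "        signature = inspect.signature(func)\n",
    "        # Use the first docstring line as the description, if there is one\n",
    "        summary = (inspect.getdoc(func) or \"No description\").splitlines()[0]\n",
    "        lines.append(f\"- {name}{signature}: {summary}\")\n",
    "    return \"\\n\".join(lines)\n",
    "\n",
    "# Hypothetical extra function, used only for this demonstration\n",
    "def divide(a: int, b: int) -> float:\n",
    "    \"\"\"Divides the first integer by the second\"\"\"\n",
    "    return a / b\n",
    "\n",
    "prompt_registry = {\"divide\": divide}\n",
    "print(describe_functions(prompt_registry))"
   ]
  },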
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "#### Running the Example\n",
    "\n",
    "To run this example, you need to ensure that you've downloaded the Yi-1.5-9B-Chat model, or change the `model_path` to the actual path of the model. Then, you can directly run this Python script.\n",
    "\n",
    "Usage example:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "Human: What is 5 plus 3?\n",
    "Model response: Here's the function call to perform the addition:\n",
    "{\"function\": \"plus\", \"arguments\": {\"a\": 5, \"b\": 3}}\n",
    "Assistant: The result of plus({'a': 5, 'b': 3}) is 8"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.8.5"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 4
}
